# QC-Bench True/False Questions Evaluation Script
# Files needed to run:
# - qc_tf.json (True/False questions from the QC-Bench benchmark)

import os
import re
import time
import json
import random
from collections import defaultdict

import requests
import openai
from anthropic import Anthropic
from google import genai
from google.genai import types

# Set random seed for reproducibility
random.seed(42)

# API keys: a value already present in the environment always wins. The "Key1"
# placeholders are only filled in for unset variables (requests made with them
# will fail authentication); previously they overwrote real keys unconditionally.
os.environ.setdefault("OPENAI_API_KEY", "Key1")
os.environ.setdefault("ANTHROPIC_API_KEY", "Key1")
os.environ.setdefault("GROQ_API_KEY", "Key1")
os.environ.setdefault("GOOGLE_API_KEY", "Key1")

# Initialize clients (GEMINI_API_KEY takes precedence over GOOGLE_API_KEY for Gemini)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)
openai.api_key = os.environ["OPENAI_API_KEY"]
anthropic_client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

# -------------------------------
# USER CONFIGURATION SECTION
# -------------------------------

# Path of the JSON file containing the true/false benchmark questions.
tf_file: str = "qc_tf.json"

# A question appears in the final "missed" report once at least this many
# models have answered it incorrectly.
miss_threshold: int = 12

# -------------------------------
# MODEL CONFIGURATION
# -------------------------------

# provider -> {"models": {API model ID: display label}}.
# The display label is what is printed and written to tf_results.json, so each
# API model ID must be the model its label names (several IDs previously pointed
# at older models, e.g. gpt-3.5-turbo reported as "GPT-4.1 nano").
model_configs = {
    "openai": {"models": {
        "gpt-4.1": "GPT-4.1",
        "gpt-4.1-mini": "GPT-4.1 mini",
        "gpt-4.1-nano": "GPT-4.1 nano",
        "gpt-4o": "GPT-4o",
        "gpt-4o-mini-2024-07-18": "GPT-4o-mini",
    }},
    "anthropic": {"models": {
        "claude-3-7-sonnet-20250219": "Claude 3.7 Sonnet",
        "claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet",
        "claude-3-5-haiku-20241022": "Claude 3.5 Haiku",
    }},
    "groq": {"models": {
        "llama3-70b-8192": "LLaMA 3 70B",
        "llama-3.3-70b-versatile": "LLaMA-3.3-70B-Versatile",
        "gemma2-9b-it": "Gemma 9b",
    }},
    "gemini": {"models": {
        "gemini-1.5-pro-latest": "Gemini 1.5 Pro",
        "gemini-2.0-flash": "Gemini 2.0 Flash",
    }},
}

# This list controls both model inclusion and order; labels missing from
# model_configs are skipped by the benchmark.
model_run_order = [
    "GPT-4.1",
    "GPT-4.1 mini",
    "GPT-4.1 nano",
    "GPT-4o",
    "GPT-4o-mini",
    "Claude 3.7 Sonnet",
    "Claude 3.5 Sonnet",
    "Claude 3.5 Haiku",
    "LLaMA 3 70B",
    "LLaMA-3.3-70B-Versatile",
    "Gemma 9b",
    "Gemini 1.5 Pro",
    "Gemini 2.0 Flash"
]

# Flatten for fast lookup: label_id -> (provider, model_id)
model_lookup = {
    label_id: (provider, model_id)
    for provider, config in model_configs.items()
    for model_id, label_id in config["models"].items()
}


# -------------------------------
# BENCHMARK EXECUTION
# -------------------------------

# Matches a standalone TRUE/FALSE word in an upper-cased model reply
# (word boundaries keep e.g. "UNTRUE" from counting as TRUE).
_ANSWER_PATTERN = re.compile(r"\b(TRUE|FALSE)\b")

# Providers _query_model knows how to call.
_SUPPORTED_PROVIDERS = frozenset({"openai", "anthropic", "groq", "gemini"})


def _build_tf_prompt(statement):
    """Return the zero-shot TRUE/FALSE prompt for one benchmark statement."""
    return f"""You are an expert in quantum computing. For the following statement, determine if it is TRUE or FALSE.

Statement: {statement}

Respond with only one word: TRUE or FALSE.

Answer:"""


def _query_model(provider, model_id, prompt):
    """Send *prompt* to *model_id* on *provider*; return the stripped, upper-cased reply.

    Raises:
        RuntimeError: Groq answered without a ``choices`` field (e.g. an error body).
        ValueError: *provider* is not in ``_SUPPORTED_PROVIDERS``.
        Exceptions from the provider SDKs or ``requests`` propagate unchanged.
    """
    if provider == "openai":
        # openai>=1.0 module-level client; it reads ``openai.api_key``.
        resp = openai.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=5,
        )
        return (resp.choices[0].message.content or "").strip().upper()

    if provider == "anthropic":
        resp = anthropic_client.messages.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=5,
        )
        return resp.content[0].text.strip().upper()

    if provider == "groq":
        resp = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {os.environ['GROQ_API_KEY']}",
                "Content-Type": "application/json",
            },
            json={
                "model": model_id,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.0,
                "max_tokens": 5,
            },
            timeout=60,  # never let one stalled request hang the whole run
        ).json()
        if "choices" not in resp:
            # Raise so the caller scores it as an error and backs off.
            raise RuntimeError(f"Groq response has no 'choices': {resp}")
        content = resp["choices"][0]["message"]["content"].strip().upper()
        time.sleep(0.6)  # Rate limiting for Groq
        return content

    if provider == "gemini":
        resp = client.models.generate_content(
            model=model_id,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.0,
                max_output_tokens=5,
            ),
        )
        # resp.text may be None when no text is returned; treat as unparseable.
        return (resp.text or "").strip().upper()

    raise ValueError(f"Unsupported provider: {provider!r}")


def _parse_prediction(content):
    """Map an upper-cased model reply to "True", "False", or "Unknown"."""
    match = _ANSWER_PATTERN.search(content)
    if match is None:
        return "Unknown"
    return "True" if match.group(1) == "TRUE" else "False"


def _normalize_solution(solution):
    """Return *solution* as "True"/"False", accepting JSON booleans or strings in any case."""
    return str(solution).strip().capitalize()


def benchmark_tf_questions():
    """Benchmark true/false quantum computing questions across all models.

    Reads ``tf_file``, and queries every model in ``model_run_order`` that has
    an entry in ``model_lookup``. Writes each model's overall accuracy to
    ``tf_results.json``, then prints three things:
    - overall accuracy;
    - questions missed by at least ``miss_threshold`` models;
    - accuracy split by TRUE vs FALSE statements.

    A failed request (an exception, or a Groq error body) is scored as wrong.
    It is recorded in the miss report with the answer "Error", so it is never
    counted as correct in the per-class breakdown.
    """
    # Load the true/false questions
    with open(tf_file, "r", encoding="utf-8") as f:
        all_questions = json.load(f)

    num_questions = len(all_questions)
    print(f"\n=== Benchmarking {num_questions} True/False Quantum Computing Questions ===")
    if num_questions == 0:
        # Nothing to score; also avoids a ZeroDivisionError in the accuracy below.
        print("No questions found; nothing to benchmark.")
        return

    solutions = [_normalize_solution(q["solution"]) for q in all_questions]

    all_results = {}
    missed_by = defaultdict(list)    # question index -> [(label_id, predicted)]
    true_correct = defaultdict(int)  # label_id -> correct answers on TRUE statements
    false_correct = defaultdict(int)  # label_id -> correct answers on FALSE statements

    for label_id in model_run_order:
        if label_id not in model_lookup:
            continue

        provider, model_id = model_lookup[label_id]
        if provider not in _SUPPORTED_PROVIDERS:
            print(f"\nSkipping {label_id}: unsupported provider {provider!r}")
            continue
        print(f"\nBenchmarking {label_id}...")

        correct = 0
        for i, q in enumerate(all_questions):
            prompt = _build_tf_prompt(q["question"])
            try:
                content = _query_model(provider, model_id, prompt)
            except Exception as e:  # one failed request must not abort the whole run
                print(f"Error with {provider.title()} ({label_id}) on question {i + 1}: {str(e)}")
                predicted = "Error"
                time.sleep(2)  # Back off in case of rate limiting
            else:
                predicted = _parse_prediction(content)
                time.sleep(0.1)  # Respect rate limits

            if predicted == solutions[i]:
                correct += 1
                if predicted == "True":
                    true_correct[label_id] += 1
                else:
                    false_correct[label_id] += 1
            else:
                missed_by[i].append((label_id, predicted))

            # Print progress every 10 questions
            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{num_questions} questions...")

        acc = correct / num_questions
        all_results[label_id] = {"overall": acc}
        print(f"{label_id} overall accuracy: {acc:.2%}")

    # Save results to JSON
    with open("tf_results.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    # -------------------------------
    # FINAL REPORT
    # -------------------------------

    print("\n=== FINAL ACCURACY REPORT (TRUE/FALSE) ===")
    print("\nOverall Accuracy:")
    for model in model_run_order:
        if model in all_results:
            print(f"{model}: {all_results[model]['overall']:.2%}")

    print(f"\n=== Questions Missed by ≥ {miss_threshold} Models ===")
    for q_idx, models_missed in missed_by.items():
        if len(models_missed) >= miss_threshold:
            question_text = all_questions[q_idx]["question"]
            correct_answer = all_questions[q_idx]["solution"]
            print(f"\nQ{q_idx + 1} missed by {len(models_missed)} models: {question_text[:100]}...")
            print(f"Correct Answer: {correct_answer}")
            model_list = ", ".join([f"{m}({ans})" for m, ans in models_missed])
            print(f"Missed by: {model_list}")

    # True vs False accuracy, from counts tallied during scoring.
    true_total = solutions.count("True")
    false_total = solutions.count("False")

    print("\n=== True vs False Statement Accuracy ===")
    for model in model_run_order:
        if model in all_results:
            true_acc = true_correct[model] / true_total if true_total > 0 else 0
            false_acc = false_correct[model] / false_total if false_total > 0 else 0
            print(f"{model}: True: {true_acc:.2%}, False: {false_acc:.2%}")


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    benchmark_tf_questions()